# Code to estimate hierarchical LASSO (hierNet) models
# DV: network ideal point
# IVs: average document scores estimated using pivot scaling

library(tidyverse)
library(data.table)
library(dtplyr)
library(network)

library(hierNet)

# Parallel backend: register 10 doMC workers for any foreach %dopar% code
# (none is visible in this section). library() rather than require() so a
# missing package stops the script immediately instead of only warning.
library(foreach)
library(doMC)
registerDoMC(cores = 10)

# Project helper functions.
source("pundits_functions.R")

# Global seed for reproducibility of any random steps in this script.
set.seed(1111)

# function to routinize fitting hierarhical lasso
run_routine <- function(input, 
                        outcome,
                        lamlist = NULL,
                        save_path = "../output/",
                        name,
                        includevars,
                        seed = 11111){
  require(hierNet)
  set.seed(seed)
  
  message('shaping input')
  X <- input %>%
    dplyr::select(all_of(includevars)) %>%
    as.matrix()
  
  X_scale <- as.matrix(data.frame(scale(X)))
  
  Y <- outcome
  if(any(is.na(Y))){
    message(paste0("removing ", sum(is.na(Y)), " observations due to missingness"))
    Y <- Y[-which(is.na(Y))]
    X_scale <- X_scale[-which(is.na(Y)),]
  }
  
  message("modeling...")
  if(is.null(lamlist)){
       mod <- hierNet.path(y = Y,
                        x = X_scale,
                        nlam = 20,
                        strong = TRUE,
                        trace = 1)
  }else{
   mod <- hierNet.path(y = Y,
                        x = X_scale,
                        lamlist = lamlist,
                        strong = TRUE,
                        trace = 1)
  }
 
    
  save(mod, file = paste0(save_path, "mod_", name, ".RData"))
}


# Predictors (per-pundit aggregated document scores) and outcomes (ideal
# points). NOTE(review): the two files are assumed to be row-aligned (same
# pundits, same order) -- only the row counts are checked here; confirm.
pundits_agg0 <- readr::read_csv("../data/pundits_parrot_aggregated_imp0.csv")

ideal.points_d1 <- readr::read_csv("../data/ideal_points_masked.csv")

stopifnot(nrow(pundits_agg0) == nrow(ideal.points_d1))

# Row indices of pundits without a tweetscore (kept for any later code).
which.na.ts <- which(is.na(ideal.points_d1$tweetscore))
# Logical mask of pundits WITH a tweetscore. A logical mask is used for
# subsetting because x[-which.na.ts, ] selects zero rows when nothing is NA.
has.ts <- !is.na(ideal.points_d1$tweetscore)

# Shared lambda grid and predictor set: every column except the first
# (presumably an identifier -- confirm).
lam_grid <- seq(from = 10, to = 200, by = 10)
predictors <- names(pundits_agg0)[-1]

message("fitting paths")
run_routine(input = pundits_agg0,
            outcome = ideal.points_d1$ideal.point,
            includevars = predictors,
            lamlist = lam_grid,
            name = "allvars")

run_routine(input = pundits_agg0[has.ts, ],
            outcome = ideal.points_d1$tweetscore[has.ts],
            includevars = predictors,
            lamlist = lam_grid,
            name = "allvars_tweetscore")